
<!DOCTYPE html>

<html lang="zh">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

    <title>Article Structure &#8212; 深入浅出PyTorch</title>
    
  <!-- Loaded before other Sphinx assets -->
  <link href="../_static/styles/theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
<link href="../_static/styles/pydata-sphinx-theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">

    
  <link rel="stylesheet"
    href="../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">

    <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
    <link rel="stylesheet" href="../_static/styles/sphinx-book-theme.css?digest=62ba249389abaaa9ffc34bf36a076bdc1d65ee18" type="text/css" />
    <link rel="stylesheet" type="text/css" href="../_static/togglebutton.css" />
    <link rel="stylesheet" type="text/css" href="../_static/mystnb.css" />
    <link rel="stylesheet" type="text/css" href="../_static/plot_directive.css" />
    
  <!-- Pre-loaded scripts that we'll load fully later -->
  <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf">

    <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
    <script src="../_static/jquery.js"></script>
    <script src="../_static/underscore.js"></script>
    <script src="../_static/doctools.js"></script>
    <script>let toggleHintShow = 'Click to show';</script>
    <script>let toggleHintHide = 'Click to hide';</script>
    <script>let toggleOpenOnPrint = 'true';</script>
    <script src="../_static/togglebutton.js"></script>
    <script src="../_static/scripts/sphinx-book-theme.js?digest=f31d14ad54b65d19161ba51d4ffff3a77ae00456"></script>
    <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
    <script>window.MathJax = {"options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
    <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
    <link rel="index" title="Index" href="../genindex.html" />
    <link rel="search" title="Search" href="../search.html" />
    <link rel="next" title="Article Structure" href="LSTM%E8%A7%A3%E8%AF%BB%E5%8F%8A%E5%AE%9E%E6%88%98.html" />
    <link rel="prev" title="ResNet Source Code Walkthrough" href="ResNet%E6%BA%90%E7%A0%81%E8%A7%A3%E8%AF%BB.html" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta name="docsearch:language" content="zh">
    

    <!-- Google Analytics -->
    
  </head>
  <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="60">
<!-- Checkboxes to toggle the left sidebar -->
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation" aria-label="Toggle navigation sidebar">
<label class="overlay overlay-navbar" for="__navigation">
    <div class="visually-hidden">Toggle navigation sidebar</div>
</label>
<!-- Checkboxes to toggle the in-page toc -->
<input type="checkbox" class="sidebar-toggle" name="__page-toc" id="__page-toc" aria-label="Toggle in-page Table of Contents">
<label class="overlay overlay-pagetoc" for="__page-toc">
    <div class="visually-hidden">Toggle in-page Table of Contents</div>
</label>
<!-- Headers at the top -->
<div class="announcement header-item noprint"></div>
<div class="header header-item noprint"></div>

    
    <div class="container-fluid" id="banner"></div>

    

    <div class="container-xl">
      <div class="row">
          
<!-- Sidebar -->
<div class="bd-sidebar noprint" id="site-navigation">
    <div class="bd-sidebar__content">
        <div class="bd-sidebar__top"><div class="navbar-brand-box">
    <a class="navbar-brand text-wrap" href="../index.html">
      
      
      
      <h1 class="site-logo" id="site-title">深入浅出PyTorch</h1>
      
    </a>
</div><form class="bd-search d-flex align-items-center" action="../search.html" method="get">
  <i class="icon fas fa-search"></i>
  <input type="search" class="form-control" name="q" id="search-input" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off" >
</form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
    <div class="bd-toc-item active">
        <p aria-level="2" class="caption" role="heading">
 <span class="caption-text">
  Contents
 </span>
</p>
<ul class="current nav bd-sidenav">
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/index.html">
   Chapter 0: Prerequisites
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  <label for="toctree-checkbox-1">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.1%20%E4%BA%BA%E5%B7%A5%E6%99%BA%E8%83%BD%E7%AE%80%E5%8F%B2.html">
     A Brief History of Artificial Intelligence
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.2%20%E8%AF%84%E4%BB%B7%E6%8C%87%E6%A0%87.html">
     Model Evaluation Metrics
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.3%20%E5%B8%B8%E7%94%A8%E5%8C%85%E7%9A%84%E5%AD%A6%E4%B9%A0.html">
     Learning Common Packages
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.4%20Jupyter%E7%9B%B8%E5%85%B3%E6%93%8D%E4%BD%9C.html">
     Jupyter Notebook/Lab: A Brief Overview
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/index.html">
   Chapter 1: Introduction to PyTorch and Installation
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  <label for="toctree-checkbox-2">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.1%20PyTorch%E7%AE%80%E4%BB%8B.html">
     1.1 Introduction to PyTorch
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.2%20PyTorch%E7%9A%84%E5%AE%89%E8%A3%85.html">
     1.2 Installing PyTorch
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.3%20PyTorch%E7%9B%B8%E5%85%B3%E8%B5%84%E6%BA%90.html">
     1.3 PyTorch Resources
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/index.html">
   Chapter 2: PyTorch Basics
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  <label for="toctree-checkbox-3">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.1%20%E5%BC%A0%E9%87%8F.html">
     2.1 Tensors
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.2%20%E8%87%AA%E5%8A%A8%E6%B1%82%E5%AF%BC.html">
     2.2 Automatic Differentiation
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.3%20%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97%E7%AE%80%E4%BB%8B.html">
     2.3 Introduction to Parallel Computing
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.4%20AI%E7%A1%AC%E4%BB%B6%E5%8A%A0%E9%80%9F%E8%AE%BE%E5%A4%87.html">
     AI Hardware Accelerators
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/index.html">
   Chapter 3: Main Components of PyTorch
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  <label for="toctree-checkbox-4">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.1%20%E6%80%9D%E8%80%83%EF%BC%9A%E5%AE%8C%E6%88%90%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%9A%84%E5%BF%85%E8%A6%81%E9%83%A8%E5%88%86.html">
     3.1 Reflection: The Essential Parts of a Deep Learning Task
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.2%20%E5%9F%BA%E6%9C%AC%E9%85%8D%E7%BD%AE.html">
     3.2 Basic Configuration
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.3%20%E6%95%B0%E6%8D%AE%E8%AF%BB%E5%85%A5.html">
     3.3 Loading Data
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.4%20%E6%A8%A1%E5%9E%8B%E6%9E%84%E5%BB%BA.html">
     3.4 Building Models
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.5%20%E6%A8%A1%E5%9E%8B%E5%88%9D%E5%A7%8B%E5%8C%96.html">
     3.5 Model Initialization
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.6%20%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
     3.6 Loss Functions
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.7%20%E8%AE%AD%E7%BB%83%E4%B8%8E%E8%AF%84%E4%BC%B0.html">
     3.7 Training and Evaluation
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.8%20%E5%8F%AF%E8%A7%86%E5%8C%96.html">
     3.8 Visualization
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.9%20%E4%BC%98%E5%8C%96%E5%99%A8.html">
     3.9 PyTorch Optimizers
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/index.html">
   Chapter 4: PyTorch Basics in Practice
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  <label for="toctree-checkbox-5">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/4.1%20ResNet.html">
     4.1 ResNet
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/4.4%20FashionMNIST%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB.html">
     Hands-on Basics: FashionMNIST Clothing Classification
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/index.html">
   Chapter 5: Defining Models in PyTorch
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
  <label for="toctree-checkbox-6">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.1%20PyTorch%E6%A8%A1%E5%9E%8B%E5%AE%9A%E4%B9%89%E7%9A%84%E6%96%B9%E5%BC%8F.html">
     5.1 Ways to Define a Model in PyTorch
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.2%20%E5%88%A9%E7%94%A8%E6%A8%A1%E5%9E%8B%E5%9D%97%E5%BF%AB%E9%80%9F%E6%90%AD%E5%BB%BA%E5%A4%8D%E6%9D%82%E7%BD%91%E7%BB%9C.html">
     5.2 Building Complex Networks Quickly from Model Blocks
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.3%20PyTorch%E4%BF%AE%E6%94%B9%E6%A8%A1%E5%9E%8B.html">
     5.3 Modifying Models in PyTorch
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.4%20PyTorh%E6%A8%A1%E5%9E%8B%E4%BF%9D%E5%AD%98%E4%B8%8E%E8%AF%BB%E5%8F%96.html">
     5.4 Saving and Loading PyTorch Models
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/index.html">
   Chapter 6: Advanced PyTorch Training Techniques
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/>
  <label for="toctree-checkbox-7">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.1%20%E8%87%AA%E5%AE%9A%E4%B9%89%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
     6.1 Custom Loss Functions
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.2%20%E5%8A%A8%E6%80%81%E8%B0%83%E6%95%B4%E5%AD%A6%E4%B9%A0%E7%8E%87.html">
     6.2 Dynamic Learning Rate Adjustment
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-torchvision.html">
     6.3 Model Fine-Tuning: torchvision
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-timm.html">
     6.3 Model Fine-Tuning: timm
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.4%20%E5%8D%8A%E7%B2%BE%E5%BA%A6%E8%AE%AD%E7%BB%83.html">
     6.4 Half-Precision Training
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.5%20%E6%95%B0%E6%8D%AE%E5%A2%9E%E5%BC%BA-imgaug.html">
     6.5 Data Augmentation: imgaug
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.6%20%E4%BD%BF%E7%94%A8argparse%E8%BF%9B%E8%A1%8C%E8%B0%83%E5%8F%82.html">
     6.6 Tuning Hyperparameters with argparse
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/index.html">
   Chapter 7: Visualization in PyTorch
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/>
  <label for="toctree-checkbox-8">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.1%20%E5%8F%AF%E8%A7%86%E5%8C%96%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84.html">
     7.1 Visualizing Network Architecture
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.2%20CNN%E5%8D%B7%E7%A7%AF%E5%B1%82%E5%8F%AF%E8%A7%86%E5%8C%96.html">
     7.2 CNN Visualization
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.3%20%E4%BD%BF%E7%94%A8TensorBoard%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.html">
     7.3 Visualizing Training with TensorBoard
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.4%20%E4%BD%BF%E7%94%A8wandb%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.html">
     7.4 Visualizing Training with wandb
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/index.html">
   Chapter 8: An Overview of the PyTorch Ecosystem
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/>
  <label for="toctree-checkbox-9">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.1%20%E6%9C%AC%E7%AB%A0%E7%AE%80%E4%BB%8B.html">
     8.1 Chapter Overview
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.2%20%E5%9B%BE%E5%83%8F%20-%20torchvision.html">
     8.2 torchvision
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.3%20%E8%A7%86%E9%A2%91%20-%20PyTorchVideo.html">
     8.3 Introduction to PyTorchVideo
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.4%20%E6%96%87%E6%9C%AC%20-%20torchtext.html">
     8.4 Introduction to torchtext
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.5%20%E9%9F%B3%E9%A2%91%20-%20torchaudio.html">
     8.5 Introduction to torchaudio
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B9%9D%E7%AB%A0/index.html">
   Chapter 9: Deploying PyTorch Models
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/>
  <label for="toctree-checkbox-10">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B9%9D%E7%AB%A0/9.1%20%E4%BD%BF%E7%94%A8ONNX%E8%BF%9B%E8%A1%8C%E9%83%A8%E7%BD%B2%E5%B9%B6%E6%8E%A8%E7%90%86.html">
     9.1 Deployment and Inference with ONNX
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 current active has-children">
  <a class="reference internal" href="index.html">
   Chapter 10: Walkthroughs of Common Code
  </a>
  <input checked="" class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/>
  <label for="toctree-checkbox-11">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul class="current">
   <li class="toctree-l2">
    <a class="reference internal" href="10.1%20%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB.html">
     10.1 Introduction to Image Classification (in progress)
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="10.2%20%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B.html">
     Introduction to Object Detection
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="10.3%20%E5%9B%BE%E5%83%8F%E5%88%86%E5%89%B2.html">
     10.3 Introduction to Image Segmentation (in progress)
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="ResNet%E6%BA%90%E7%A0%81%E8%A7%A3%E8%AF%BB.html">
     ResNet Source Code Walkthrough
    </a>
   </li>
   <li class="toctree-l2 current active">
    <a class="current reference internal" href="#">
     Article Structure
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="LSTM%E8%A7%A3%E8%AF%BB%E5%8F%8A%E5%AE%9E%E6%88%98.html">
     Article Structure
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="Transformer%20%E8%A7%A3%E8%AF%BB.html">
     Transformer Explained
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="ViT%E8%A7%A3%E8%AF%BB.html">
     ViT Explained
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="Swin-Transformer%E8%A7%A3%E8%AF%BB.html">
     Swin Transformer Explained
    </a>
   </li>
  </ul>
 </li>
</ul>

    </div>
</nav></div>
        <div class="bd-sidebar__bottom">
             <!-- To handle the deprecated key -->
            
            <div class="navbar_extra_footer">
            Theme by the <a href="https://ebp.jupyterbook.org">Executable Book Project</a>
            </div>
            
        </div>
    </div>
    <div id="rtd-footer-container"></div>
</div>


          


          
<!-- A tiny helper pixel to detect if we've scrolled -->
<div class="sbt-scroll-pixel-helper"></div>
<!-- Main content -->
<div class="col py-0 content-container">
    
    <div class="header-article row sticky-top noprint">
        



<div class="col py-1 d-flex header-article-main">
    <div class="header-article__left">
        
        <label for="__navigation"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="right"
title="Toggle navigation"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-bars"></i>
  </span>

</label>

        
    </div>
    <div class="header-article__right">
<button onclick="toggleFullScreen()"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="bottom"
title="Fullscreen mode"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-expand"></i>
  </span>

</button>

<div class="menu-dropdown menu-dropdown-repository-buttons">
  <button class="headerbtn menu-dropdown__trigger"
      aria-label="Source repositories">
      <i class="fab fa-github"></i>
  </button>
  <div class="menu-dropdown__content">
    <ul>
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Source repository"
>
  

<span class="headerbtn__icon-container">
  <i class="fab fa-github"></i>
  </span>
<span class="headerbtn__text-container">repository</span>
</a>

      </li>
      
      
    </ul>
  </div>
</div>

<div class="menu-dropdown menu-dropdown-download-buttons">
  <button class="headerbtn menu-dropdown__trigger"
      aria-label="Download this page">
      <i class="fas fa-download"></i>
  </button>
  <div class="menu-dropdown__content">
    <ul>
      <li>
        <a href="../_sources/第十章/RNN详解及其实现.md.txt"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Download source file"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-file"></i>
  </span>
<span class="headerbtn__text-container">.md</span>
</a>

      </li>
      
      <li>
        
<button onclick="printPdf(this)"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="left"
title="Print to PDF"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-file-pdf"></i>
  </span>
<span class="headerbtn__text-container">.pdf</span>
</button>

      </li>
      
    </ul>
  </div>
</div>
<label for="__page-toc"
  class="headerbtn headerbtn-page-toc"
  
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-list"></i>
  </span>

</label>

    </div>
</div>

<!-- Table of contents -->
<div class="col-md-3 bd-toc show noprint">
    <div class="tocsection onthispage pt-5 pb-3">
        <i class="fas fa-list"></i> Contents
    </div>
    <nav id="bd-toc-nav" aria-label="Page">
        <ul class="visible nav section-nav flex-column">
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#">
   Article Structure
  </a>
 </li>
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#rnn">
   Why Do We Need RNNs?
  </a>
 </li>
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id2">
   Understanding RNNs and a Simple Implementation
  </a>
 </li>
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id3">
   Text Classification with an RNN
  </a>
 </li>
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id4">
   Problems with RNNs
  </a>
 </li>
</ul>

    </nav>
</div>
    </div>
    <div class="article row">
        <div class="col pl-md-3 pl-lg-5 content-container">
            <!-- Table of contents that is only displayed when printing the page -->
            <div id="jb-print-docs-body" class="onlyprint">
                <h1>Article Structure</h1>
                <!-- Table of contents -->
                <div id="print-main-content">
                    <div id="jb-print-toc">
                        
                        <div>
                            <h2> Contents </h2>
                        </div>
                        <nav aria-label="Page">
                            <ul class="visible nav section-nav flex-column">
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#">
   Article Structure
  </a>
 </li>
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#rnn">
   Why Do We Need RNNs?
  </a>
 </li>
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id2">
   Understanding RNNs and a Simple Implementation
  </a>
 </li>
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id3">
   Text Classification with an RNN
  </a>
 </li>
 <li class="toc-h1 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id4">
   Problems with RNNs
  </a>
 </li>
</ul>

                        </nav>
                    </div>
                </div>
            </div>
            <main id="main-content" role="main">
                
              <div>
                
  <section class="tex2jax_ignore mathjax_ignore" id="id1">
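<h1>Article Structure<a class="headerlink" href="#id1" title="Permalink to this heading">#</a></h1>
<p>When RNNs come up, most people know that they are neural networks for sequence tasks and that they retain temporal information. But why do we need to take temporal information into account at all? In what sense does an RNN retain it? And what problems do RNNs have? This article walks you through the details of RNNs in the following order and uses PyTorch to build a text classification model of your own.</p>
<ol class="simple">
<li><p>Why do we need RNNs?</p></li>
<li><p>Understanding RNNs and a simple implementation.</p></li>
<li><p>Text classification with an RNN.</p></li>
<li><p>Problems with RNNs.</p></li>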
</ol>
</section>
<section class="tex2jax_ignore mathjax_ignore" id="rnn">
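<h1>Why Do We Need RNNs?<a class="headerlink" href="#rnn" title="Permalink to this heading">#</a></h1>
<p>In the real world, a great deal of content has a before-and-after order. The text you are reading, for example, is not a random combination of characters: I planned it and then wrote it down word by word, in order. Beyond text, human speech, the price curve of a product, changes in temperature, and so on all unfold in a particular order.</p>
<p>Clearly, once we know what came earlier, we can make reasonable predictions about what comes next. If the temperature has stayed around 20 degrees for the past ten days, it is very unlikely to drop below zero tomorrow; if an item's price has fluctuated around 30 for a whole year, bringing 40 when you go to buy it tomorrow should be enough; and if your teacher praises you warmly and then says &quot;but&quot;, you know the conversation is about to take a turn. This is the sequential information hidden in everyday life: because we already know what happened before, we can infer what comes after.</p>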
<p>So, can a conventional multilayer perceptron (MLP) handle sequence problems? Following the basic MLP recipe, each input sequence would first be fixed as a <span class="math notranslate nohighlight">\(d\)</span>-dimensional vector and then fed into the MLP for learning:</p>
<div class="math notranslate nohighlight">
\[
\mathbf{H} = \phi(\mathbf{X} \mathbf{W}_{xh} + \mathbf{b}_{h}) \tag{1}
\]
</div>
<p>Here <span class="math notranslate nohighlight">\(\phi\)</span> is the activation function and <span class="math notranslate nohighlight">\(\mathbf{X} \in \mathbb{R}^{n \times d}\)</span> is a minibatch of samples, where <span class="math notranslate nohighlight">\(n\)</span> is the batch size and <span class="math notranslate nohighlight">\(d\)</span> is the number of input features. <span class="math notranslate nohighlight">\(\mathbf{W}_{xh} \in \mathbb{R}^{d \times h}\)</span> is the weight matrix and <span class="math notranslate nohighlight">\(\mathbf{b}_{h} \in \mathbb{R}^{1 \times h}\)</span> is the bias, which is broadcast to every row. The result is the hidden-layer output <span class="math notranslate nohighlight">\(\mathbf{H} \in \mathbb{R}^{n \times h}\)</span>, where <span class="math notranslate nohighlight">\(h\)</span> is the number of hidden units.</p>
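<p>To make the shapes in this formula concrete, here is a minimal PyTorch sketch of the hidden-layer computation. The sizes <code class="docutils literal notranslate"><span class="pre">n</span></code>, <code class="docutils literal notranslate"><span class="pre">d</span></code> and <code class="docutils literal notranslate"><span class="pre">h</span></code> are arbitrary illustrative values, and <code class="docutils literal notranslate"><span class="pre">torch.tanh</span></code> is an assumed choice for <span class="math notranslate nohighlight">\(\phi\)</span>; the bias of shape <span class="math notranslate nohighlight">\(1 \times h\)</span> is broadcast across all <span class="math notranslate nohighlight">\(n\)</span> rows.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
```python
import torch

torch.manual_seed(0)

n, d, h = 4, 5, 3          # batch size, input features, hidden units (illustrative values)
X = torch.randn(n, d)      # a minibatch of n samples, each a fixed d-dimensional vector
W_xh = torch.randn(d, h)   # hidden-layer weights
b_h = torch.zeros(1, h)    # hidden-layer bias, broadcast over the n rows

# H = phi(X W_xh + b_h), with tanh as the assumed activation phi
H = torch.tanh(X @ W_xh + b_h)
print(H.shape)  # torch.Size([4, 3])
```
</pre></div>
</div>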
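<p>The model then computes its output with:</p>
<div class="math notranslate nohighlight">
\[
\mathbf{O} = \mathbf{H}\mathbf{W}_{hq} + \mathbf{b}_q \tag{2}
\]
</div>
<p>where <span class="math notranslate nohighlight">\(\mathbf{W}_{hq} \in \mathbb{R}^{h \times q}\)</span> and <span class="math notranslate nohighlight">\(\mathbf{b}_q \in \mathbb{R}^{1 \times q}\)</span> are the output-layer weights and bias, and <span class="math notranslate nohighlight">\(\mathbf{O} \in \mathbb{R}^{n \times q}\)</span> is the model output, with <span class="math notranslate nohighlight">\(q\)</span> the number of outputs. Because the task in this article is text classification, <span class="math notranslate nohighlight">\(q\)</span> is the number of text categories, and <span class="math notranslate nohighlight">\(\mathrm{softmax}(\mathbf{O})\)</span> converts the outputs into class probabilities.</p>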
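<p>Continuing in the same spirit, the output layer and the softmax can be sketched as follows. The hidden output <code class="docutils literal notranslate"><span class="pre">H</span></code> here is a random stand-in and <code class="docutils literal notranslate"><span class="pre">q</span> <span class="pre">=</span> <span class="pre">2</span></code> classes is an arbitrary assumption; each row of <code class="docutils literal notranslate"><span class="pre">probs</span></code> is one probability distribution over the <span class="math notranslate nohighlight">\(q\)</span> classes.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
```python
import torch

torch.manual_seed(0)

n, h, q = 4, 3, 2                 # batch size, hidden units, number of text classes (illustrative)
H = torch.randn(n, h)             # stand-in for the hidden-layer output
W_hq = torch.randn(h, q)          # output-layer weights
b_q = torch.zeros(1, q)           # output-layer bias

O = H @ W_hq + b_q                # O = H W_hq + b_q, shape (n, q)
probs = torch.softmax(O, dim=1)   # class probabilities for each sample
pred = probs.argmax(dim=1)        # predicted class index for each sample
```
</pre></div>
</div>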
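<p>However, this procedure has an obvious precondition: <strong>the input must be fixed as a <span class="math notranslate nohighlight">\(d\)</span>-dimensional vector</strong>. In other words, a conventional MLP cannot process variable-length sequences. Yet <strong>in sequence tasks, sequence lengths clearly differ</strong>: we may need to predict tomorrow from one day of data, or from a whole year of data. If we still want to use a conventional MLP in that situation, we face a serious question: how do we turn one day of data and one year of data into the same <span class="math notranslate nohighlight">\(d\)</span>-dimensional vector?</p>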
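<p>The fixed-<span class="math notranslate nohighlight">\(d\)</span> constraint is easy to see in code. In this sketch (sizes are illustrative), an <code class="docutils literal notranslate"><span class="pre">nn.Linear(d,</span> <span class="pre">h)</span></code> layer only accepts inputs with exactly <span class="math notranslate nohighlight">\(d\)</span> features, so a 3-step sequence fails outright. The helper <code class="docutils literal notranslate"><span class="pre">to_fixed_length</span></code> is a hypothetical, naive workaround (zero-pad short sequences, keep only the last <span class="math notranslate nohighlight">\(d\)</span> steps of long ones): it makes the shapes fit, but it throws away almost all of the year-long sequence.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
```python
import torch
from torch import nn

torch.manual_seed(0)

d, h = 10, 8
mlp = nn.Sequential(nn.Linear(d, h), nn.Tanh())  # expects exactly d features per sample

short_seq = torch.randn(1, 3)    # e.g. 3 days of readings
long_seq = torch.randn(1, 365)   # e.g. a year of readings


def to_fixed_length(x, d):
    """Hypothetical helper: zero-pad short inputs, keep only the last d steps of long ones."""
    length = x.shape[1]
    if length >= d:
        return x[:, -d:]
    return torch.cat([torch.zeros(x.shape[0], d - length), x], dim=1)


try:
    mlp(short_seq)
    failed = False
except RuntimeError:
    failed = True  # a 3-feature input cannot be multiplied with Linear(10, 8)'s weights

out_short = mlp(to_fixed_length(short_seq, d))
out_long = mlp(to_fixed_length(long_seq, d))  # 355 of the 365 readings were discarded
```
</pre></div>
</div>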
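<p>Sequence data has one more property: <strong>a piece of information may appear at different positions in the sequence, and even though it appears in different places, it may carry the same meaning</strong>.</p>
<p>For example, when a teacher has praised us for half an hour and then says &quot;but...&quot;, we usually don't worry, because they probably just want to point out a few minor issues. If they offer a single sentence of praise and immediately follow it with &quot;but&quot;, we had better brace ourselves for half an hour of scolding. There is yet another possibility: the teacher criticizes you at length and then turns with &quot;but&quot;, and you feel relieved, because you know the conversation is about to end. In each case, the judgment we make at the moment the teacher says &quot;but&quot; depends on the preceding context (what the praise was about and how long it lasted).</p>
<p>Both problems above seem hard to solve with an MLP alone. Fortunately, RNNs approach them from a more natural angle: <strong>remember what has been seen so far, and combine it with what is being seen now to predict what is likely to come next.</strong></p>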
</section>
<section class="tex2jax_ignore mathjax_ignore" id="id2">
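<h1>Understanding RNNs and a Simple Implementation<a class="headerlink" href="#id2" title="Permalink to this heading">#</a></h1>
<p>After the discussion above, you should have a basic sense of why a conventional MLP cannot handle sequence information well. Next, let's look at how an RNN remembers what came before.</p>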
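<p>Here is the RNN formula; compare it with the MLP formula above:</p>
<div class="math notranslate nohighlight">
\[
\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{xh} + \mathbf{H}_{t-1} \mathbf{W}_{hh} + \mathbf{b}_h) \tag{3}
\]
</div>
<p>Compared with the previous formula, the most obvious difference is the additional term <span class="math notranslate nohighlight">\(\mathbf{H}_{t-1} \mathbf{W}_{hh}\)</span>, where <span class="math notranslate nohighlight">\(\mathbf{W}_{hh} \in \mathbb{R}^{h \times h}\)</span>. Its meaning is straightforward: <span class="math notranslate nohighlight">\(\mathbf{H}_{t-1}\)</span> is the hidden state from the previous time step and represents <strong>what has been seen so far</strong>; combining it with the current input <span class="math notranslate nohighlight">\(\mathbf{X}_t\)</span> yields the hidden state of the current time step, <span class="math notranslate nohighlight">\(\mathbf{H}_{t}\)</span>. Once computed, <span class="math notranslate nohighlight">\(\mathbf{H}_{t}\)</span> is used in the computation at the next time step, with the same weights <span class="math notranslate nohighlight">\(\mathbf{W}_{xh}\)</span>, <span class="math notranslate nohighlight">\(\mathbf{W}_{hh}\)</span> and bias <span class="math notranslate nohighlight">\(\mathbf{b}_h\)</span> reused at every step.</p>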
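<p>One step of this recurrence can be sketched with plain tensors and checked against PyTorch's built-in <code class="docutils literal notranslate"><span class="pre">nn.RNNCell</span></code>, which with <code class="docutils literal notranslate"><span class="pre">nonlinearity=&quot;tanh&quot;</span></code> computes <span class="math notranslate nohighlight">\(\tanh(x W_{ih}^\top + b_{ih} + h W_{hh}^\top + b_{hh})\)</span>. The two differ only in layout: <code class="docutils literal notranslate"><span class="pre">RNNCell</span></code> stores its weights transposed relative to the formula above and keeps two bias vectors, whose sum plays the role of <span class="math notranslate nohighlight">\(\mathbf{b}_h\)</span>. Sizes are illustrative.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
```python
import torch
from torch import nn

torch.manual_seed(0)

n, d, h = 2, 4, 3
cell = nn.RNNCell(input_size=d, hidden_size=h, nonlinearity="tanh")

X_t = torch.randn(n, d)      # input at time step t
H_prev = torch.randn(n, h)   # hidden state from time step t-1

# RNNCell stores weights as (out_features, in_features); transpose them
# to match the X_t W_xh + H_{t-1} W_hh layout of the formula
W_xh = cell.weight_ih.T             # (d, h)
W_hh = cell.weight_hh.T             # (h, h)
b_h = cell.bias_ih + cell.bias_hh   # the two PyTorch biases together act as b_h

with torch.no_grad():
    H_manual = torch.tanh(X_t @ W_xh + H_prev @ W_hh + b_h)
    H_cell = cell(X_t, H_prev)
```
</pre></div>
</div>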
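<p>This iteration can also stop at any time step: passing the hidden state obtained so far to an output layer gives the output directly:</p>
<div class="math notranslate nohighlight">
\[
\mathbf{O}_t = \mathbf{H}_{t} \mathbf{W}_{hq} + \mathbf{b}_q \tag{4}
\]
</div>
<p>Equation (4) is almost identical to Equation (2); the only difference is that the hidden state is now the time-step-specific <span class="math notranslate nohighlight">\(\mathbf{H}_t\)</span> instead of <span class="math notranslate nohighlight">\(\mathbf{H}\)</span>.</p>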
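<p>The sketch below unrolls the recurrence over two sequences of very different lengths. Everything here is an illustrative assumption: tanh as <span class="math notranslate nohighlight">\(\phi\)</span>, a zero initial state, small random weights, and the hypothetical function name <code class="docutils literal notranslate"><span class="pre">rnn_classify</span></code>. The same five parameter tensors serve a 3-step and a 365-step sequence, and the output layer is applied to the final hidden state.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
```python
import torch

torch.manual_seed(0)

d, h, q = 4, 8, 2                 # features per step, hidden units, classes (illustrative)
W_xh = torch.randn(d, h) * 0.1
W_hh = torch.randn(h, h) * 0.1
b_h = torch.zeros(1, h)
W_hq = torch.randn(h, q) * 0.1
b_q = torch.zeros(1, q)


def rnn_classify(X):
    """Hypothetical helper. X has shape (num_steps, n, d); parameters are shared across steps."""
    H = torch.zeros(X.shape[1], h)                   # H_0 = 0
    for X_t in X:                                    # iterate over time steps
        H = torch.tanh(X_t @ W_xh + H @ W_hh + b_h)  # hidden-state update
    return H @ W_hq + b_q                            # output from the final hidden state


O_short = rnn_classify(torch.randn(3, 1, d))    # 3 time steps
O_long = rnn_classify(torch.randn(365, 1, d))   # 365 time steps, same parameters
```
</pre></div>
</div>
<p>Because the parameter shapes depend only on <span class="math notranslate nohighlight">\(d\)</span>, <span class="math notranslate nohighlight">\(h\)</span> and <span class="math notranslate nohighlight">\(q\)</span>, the sequence length never enters the model definition, which is how the recurrence sidesteps the fixed-length problem of the MLP.</p>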
<p>With these formulas in hand, we can already build a simple RNN model:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">RNNDemo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">output_size</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">CharRNNClassify</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">hidden_size</span>
        <span class="c1"># 计算隐藏状态 H</span>
        <span class="c1"># 因为要用以下一次计算，所以输出维度应该是 hidden_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">i2h</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">input_size</span> <span class="o">+</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)</span>
        <span class="c1"># 输出结果 O</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">i2o</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">input_size</span> <span class="o">+</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">output_size</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LogSoftmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">input</span><span class="p">,</span> <span class="n">hidden</span><span class="p">):</span>
        <span class="c1"># 将 X 和 Ht-1 合并</span>
        <span class="n">combined</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">((</span><span class="nb">input</span><span class="p">,</span> <span class="n">hidden</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span>
        <span class="c1"># Compute H_t</span>
        <span class="n">hidden</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">i2h</span><span class="p">(</span><span class="n">combined</span><span class="p">)</span>
        <span class="c1"># Compute the output for the current time step</span>
        <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">i2o</span><span class="p">(</span><span class="n">combined</span><span class="p">)</span>
        <span class="c1"># For classification, LogSoftmax turns the scores into log-probabilities</span>
        <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
        <span class="c1"># Return the prediction and the current hidden state</span>
        <span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">hidden</span>
    <span class="k">def</span> <span class="nf">initHidden</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="c1"># Start from an all-zero H0 so a random initial state does not bias later results</span>
        <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span>
</pre></div>
</div>
<p>To connect the code with the formula: by Equation 3, <span class="math notranslate nohighlight">\(\mathbf{W}_{xh}\)</span> and <span class="math notranslate nohighlight">\(\mathbf{W}_{hh}\)</span> act on the input and on the previous hidden state independently, and their results are added. That is why <code class="docutils literal notranslate"><span class="pre">self.i2h</span> <span class="pre">=</span> <span class="pre">nn.Linear(input_size</span> <span class="pre">+</span> <span class="pre">hidden_size,</span> <span class="pre">hidden_size)</span></code> widens its input dimension: concatenating the two inputs and applying one layer is equivalent to applying two separate weight matrices and summing. The first <code class="docutils literal notranslate"><span class="pre">input_size</span></code> input features correspond to <span class="math notranslate nohighlight">\(\mathbf{W}_{xh}\)</span>, and the remaining <code class="docutils literal notranslate"><span class="pre">hidden_size</span></code> features correspond to <span class="math notranslate nohighlight">\(\mathbf{W}_{hh}\)</span>.</p>
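<p>The equivalence can be checked numerically. The minimal sketch below uses arbitrary sizes and random inputs (illustrative assumptions, not values from the tutorial). It relies on the fact that <code>nn.Linear</code> stores its weight with shape <code>(out_features, in_features)</code> and computes <code>x @ weight.T + bias</code>, so slicing the columns of <code>i2h.weight</code> recovers the two blocks.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 5, 3

# One layer acting on the concatenation [X_t, H_{t-1}], as in self.i2h
i2h = nn.Linear(input_size + hidden_size, hidden_size)

x = torch.randn(1, input_size)    # X_t
h = torch.randn(1, hidden_size)   # H_{t-1}

# Path 1: concatenate, then apply the single layer (what the RNN class does)
combined_out = i2h(torch.cat((x, h), dim=1))

# Path 2: split the weight into W_xh and W_hh and apply them separately
W = i2h.weight                    # shape: (hidden_size, input_size + hidden_size)
W_xh = W[:, :input_size].T        # shape: (input_size, hidden_size)
W_hh = W[:, input_size:].T        # shape: (hidden_size, hidden_size)
separate_out = x @ W_xh + h @ W_hh + i2h.bias

print(torch.allclose(combined_out, separate_out, atol=1e-6))  # True
```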
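<p>After reading the code, revisit the two questions raised in the first part in light of the code and the formulas; answering them leads to a deeper analysis of the RNN.</p>
<p>The first question: how can sequences of different lengths be represented?</p>
<p>The formulas show that an RNN does not require different sequences to have the same dimension; it only requires every element of a sequence to be represented with the same dimension. We can therefore treat the value fed in at time step <span class="math notranslate nohighlight">\(t\)</span> as <span class="math notranslate nohighlight">\(\mathbf{X}_t\)</span> and combine it with the hidden state <span class="math notranslate nohighlight">\(\mathbf{H}_{t-1}\)</span> computed from the earlier inputs to obtain the result for the current step. Whatever the actual length of the sequence, we can stop at any step and turn the hidden state into an output; we can even produce outputs while the input is still arriving.
[Figure to be added]</p>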
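<p>A minimal sketch of this idea, reusing the structure of the <code>RNN</code> class above with small arbitrary sizes and random inputs (the helper <code>run</code> is hypothetical): two sequences of different lengths pass through the same model, every step consumes a vector of the same size, and a prediction is available after every step.</p>

```python
import torch
import torch.nn as nn


class RNN(nn.Module):
    # Same structure as the RNN class above
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.softmax(self.i2o(combined))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)


torch.manual_seed(0)
rnn = RNN(input_size=4, hidden_size=8, output_size=2)


def run(sequence):
    """Feed a (seq_len, 1, input_size) tensor step by step and keep every output."""
    hidden = rnn.initHidden()
    outputs = []
    for x_t in sequence:              # every x_t has shape (1, input_size)
        output, hidden = rnn(x_t, hidden)
        outputs.append(output)        # a prediction exists after each step
    return outputs, hidden


short_outputs, h_short = run(torch.randn(3, 1, 4))   # length 3
long_outputs, h_long = run(torch.randn(7, 1, 4))     # length 7
print(len(short_outputs), len(long_outputs))  # 3 7
print(h_short.shape, h_long.shape)            # torch.Size([1, 8]) torch.Size([1, 8])
```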
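<p>The second question: the same piece of information may appear at different positions in a sequence while carrying the same meaning.</p>
<p>This is hard to see from the formulas and the code alone, but convolutional neural networks offer a useful analogy.</p>
<p>Convolutional neural networks are translation equivariant: if a pattern in the input <span class="math notranslate nohighlight">\(\mathbf{X}\)</span> moves to another position, the response to it moves by the same amount without otherwise changing. This comes from parameter sharing in the convolution kernel: the same kernel weights are applied at every position of the image, so the same input patch produces the same result wherever it appears.</p>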
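<p>A pure-Python sketch of this property for a 1-D convolution (implemented as cross-correlation, which is what PyTorch's convolution layers compute); the helper <code>conv1d</code>, the kernel and the signal are hypothetical values chosen for illustration. Shifting the input one position to the right shifts the output by one position as well, because the same kernel is applied at every position.</p>

```python
def conv1d(signal, kernel):
    """Valid 1-D cross-correlation: slide the same kernel over every position."""
    k = len(kernel)
    return [sum(signal[i + j] * kernel[j] for j in range(k))
            for i in range(len(signal) - k + 1)]


kernel = [1, -1]                 # one set of weights shared by all positions
x = [0, 0, 3, 5, 0, 0, 0]
x_shifted = [0] + x[:-1]         # the same pattern moved one step to the right

y = conv1d(x, kernel)
y_shifted = conv1d(x_shifted, kernel)
print(y)          # [0, -3, -2, 5, 0, 0]
print(y_shifted)  # [0, 0, -3, -2, 5, 0]
```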
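<p>Returning to the RNN: the weight matrices applied to <span class="math notranslate nohighlight">\(\mathbf{X}\)</span> and <span class="math notranslate nohighlight">\(\mathbf{H}\)</span> are the same at every time step, that is, <span class="math notranslate nohighlight">\(\mathbf{W}_{xh}\)</span> and <span class="math notranslate nohighlight">\(\mathbf{W}_{hh}\)</span> are shared across time. So the same input is processed by exactly the same transformation no matter where in the sequence it appears; the output of a step differs only through the hidden state carried in from earlier steps.</p>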
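<p>Parameter sharing also means the number of parameters does not depend on the sequence length. A minimal sketch with arbitrary sizes (the helper <code>final_hidden</code> is hypothetical) reuses a single layer shaped like <code>self.i2h</code> at every step:</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size = 5, 8

# A single set of recurrent weights, reused at every time step
i2h = nn.Linear(input_size + hidden_size, hidden_size)


def final_hidden(sequence):
    """Run the shared layer over a (seq_len, 1, input_size) tensor."""
    h = torch.zeros(1, hidden_size)
    for x_t in sequence:
        h = i2h(torch.cat((x_t, h), dim=1))   # same weights at t = 0, 1, 2, ...
    return h


n_params = sum(p.numel() for p in i2h.parameters())
h_short = final_hidden(torch.randn(3, 1, input_size))
h_long = final_hidden(torch.randn(300, 1, input_size))

# (input_size + hidden_size) * hidden_size weights + hidden_size biases,
# whether the sequence has 3 steps or 300
print(n_params)  # 112
```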
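<p>With the background of the RNN in place, we now analyze it more closely through a real text-classification task. (Note: this example is taken from the <a class="reference external" href="https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html">official PyTorch tutorial</a>.)</p>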
</section>
<section class="tex2jax_ignore mathjax_ignore" id="id3">
<h1>Text Classification with an RNN<a class="headerlink" href="#id3" title="永久链接至标题">#</a></h1>
<p>A basic machine-learning task follows this workflow: data analysis, data transformation, building the model, defining the training function, training, saving the model, and evaluating the model.</p>
<p>Here we excerpt and explain some key code from the official tutorial. You can <a class="reference external" href="https://pytorch.org/tutorials/_downloads/13b143c2380f4768d9432d808ad50799/char_rnn_classification_tutorial.ipynb">click here</a> to download the official notebook and train it directly. The training data is available <a class="reference external" href="https://download.pytorch.org/tutorial/data.zip">here</a>.</p>
<p>The first executable cell of the notebook defines the usable characters <code class="docutils literal notranslate"><span class="pre">all_letters</span></code> and their count <code class="docutils literal notranslate"><span class="pre">n_letters</span></code>. It also converts the downloaded data to ASCII and stores it in a dict of the form <code class="docutils literal notranslate"><span class="pre">{language:</span> <span class="pre">[names</span> <span class="pre">...]}</span></code>.</p>
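<p>A sketch of this preprocessing, following the approach of the official tutorial as far as it is described here: the exact character set is an assumption, and the in-memory <code>raw</code> dictionary is hypothetical data standing in for the per-language name files in the download. Accented letters are decomposed with Unicode NFD normalization, combining marks are dropped, and any character outside <code>all_letters</code> is discarded.</p>

```python
import string
import unicodedata

# Allowed characters: ASCII letters plus a few punctuation marks (assumed set)
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)


def unicodeToAscii(s):
    """Decompose accented characters (NFD), drop combining marks, keep allowed characters."""
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn' and c in all_letters
    )


# Hypothetical data in the {language: [names ...]} format
raw = {'Polish': ['Ślusarski'], 'French': ['Béringer', 'Lefèvre']}
category_lines = {lang: [unicodeToAscii(n) for n in names] for lang, names in raw.items()}

print(n_letters)        # 57
print(category_lines)   # {'Polish': ['Slusarski'], 'French': ['Beringer', 'Lefevre']}
```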
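<p>The third cell defines three functions whose purpose is to turn a line (here, a name) of <span class="math notranslate nohighlight">\(n\)</span> characters into an <span class="math notranslate nohighlight">\(n \times d\)</span> representation, where <span class="math notranslate nohighlight">\(d\)</span> is the dimension of the character features, here a One-Hot encoding. Because each One-Hot vector has shape <span class="math notranslate nohighlight">\(1 \times n\_letters\)</span>, the final tensor has shape <span class="math notranslate nohighlight">\(n \times 1 \times n\_letters\)</span>; the extra dimension of size 1 is a batch dimension holding a single name.</p>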
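<p>The three functions can be sketched as follows. The names follow the tutorial, but treat the bodies as a minimal reconstruction rather than a verbatim copy; the character set is the same assumed one as in the preprocessing step.</p>

```python
import string

import torch

all_letters = string.ascii_letters + " .,;'"   # assumed character set
n_letters = len(all_letters)


def letterToIndex(letter):
    # Position of the letter in all_letters, e.g. 'a' -> 0
    return all_letters.find(letter)


def letterToTensor(letter):
    # One-Hot vector of shape (1, n_letters)
    tensor = torch.zeros(1, n_letters)
    tensor[0][letterToIndex(letter)] = 1
    return tensor


def lineToTensor(line):
    # One One-Hot row per character: shape (len(line), 1, n_letters);
    # the middle dimension of size 1 is a batch of one name
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor


print(letterToTensor('J').shape)    # torch.Size([1, 57])
print(lineToTensor('Jones').shape)  # torch.Size([5, 1, 57])
```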
<p>Cell 4 defines the basic RNN model, with the same code as above, and sets the hidden size to 128. Cells 5 and 6 then run a quick test of the RNN.</p>
<p>Here is a brief walk-through of cell 6, whose code is:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Convert Albert into a 6 * 1 * n_letters tensor</span>
<span class="nb">input</span> <span class="o">=</span> <span class="n">lineToTensor</span><span class="p">(</span><span class="s1">&#39;Albert&#39;</span><span class="p">)</span>
<span class="c1"># H0 is all zeros for the reason given above</span>
<span class="n">hidden</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_hidden</span><span class="p">)</span>
<span class="c1"># Feed the first character and get the output and H1</span>
<span class="n">output</span><span class="p">,</span> <span class="n">next_hidden</span> <span class="o">=</span> <span class="n">rnn</span><span class="p">(</span><span class="nb">input</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">hidden</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
</pre></div>
</div>
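<p>In <code class="docutils literal notranslate"><span class="pre">rnn(input[0],</span> <span class="pre">hidden)</span></code>, the input is the One-Hot encoding of the first letter, 'A'. The returned <code class="docutils literal notranslate"><span class="pre">output</span></code> contains the log-probabilities of each language category (a classification score, not a prediction of the next character), and <code class="docutils literal notranslate"><span class="pre">next_hidden</span></code> is the new hidden state <span class="math notranslate nohighlight">\(\mathbf{H}_1\)</span>, which is fed together with <code class="docutils literal notranslate"><span class="pre">input[1]</span></code> into the next step.</p>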
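<p>The shapes and the meaning of <code>output</code> can be checked with a self-contained sketch. It repeats the <code>RNN</code> class and a compact <code>lineToTensor</code>; <code>n_categories = 18</code> is an assumption about the number of languages in the dataset, and the weights are untrained, so the scores themselves carry no meaning yet.</p>

```python
import string

import torch
import torch.nn as nn

all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)   # 57
n_hidden = 128                 # hidden size used in the tutorial
n_categories = 18              # assumption: number of languages


class RNN(nn.Module):
    # Same structure as the RNN class shown earlier
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.softmax(self.i2o(combined))
        return output, hidden


def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][all_letters.find(letter)] = 1
    return tensor


torch.manual_seed(0)
rnn = RNN(n_letters, n_hidden, n_categories)
input = lineToTensor('Albert')
hidden = torch.zeros(1, n_hidden)
output, next_hidden = rnn(input[0], hidden)   # only the first letter, 'A'

print(output.shape)               # torch.Size([1, 18]): one score per language
print(output.exp().sum().item())  # about 1.0: exp of log-probabilities sums to 1
print(next_hidden.shape)          # torch.Size([1, 128]): state passed on with input[1]
```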
<p>Skipping cells 7, 8 and 9, we now look at the cell that defines <code class="docutils literal notranslate"><span class="pre">train</span></code>:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Learning rate</span>
<span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.005</span> 
<span class="c1"># Arguments: category_tensor is the class label, used to compute the loss</span>
<span class="c1"># line_tensor is the tensor built from one name, shape: n * 1 * n_letters</span>
<span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">category_tensor</span><span class="p">,</span> <span class="n">line_tensor</span><span class="p">):</span>
    <span class="c1"># Initialize H0</span>
    <span class="n">hidden</span> <span class="o">=</span> <span class="n">rnn</span><span class="o">.</span><span class="n">initHidden</span><span class="p">()</span>
    <span class="c1"># Reset the gradients</span>
    <span class="n">rnn</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span>
    <span class="c1"># Feed the characters one at a time</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">line_tensor</span><span class="o">.</span><span class="n">size</span><span class="p">()[</span><span class="mi">0</span><span class="p">]):</span>
        <span class="n">output</span><span class="p">,</span> <span class="n">hidden</span> <span class="o">=</span> <span class="n">rnn</span><span class="p">(</span><span class="n">line_tensor</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">hidden</span><span class="p">)</span>
    <span class="c1"># Compute the loss (criterion = nn.NLLLoss() is defined in an earlier cell)</span>
    <span class="n">loss</span> <span class="o">=</span> <span class="n">criterion</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">category_tensor</span><span class="p">)</span>
    <span class="c1"># Backpropagate</span>
    <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span>

    <span class="c1"># Add parameters&#39; gradients to their values, multiplied by learning rate</span>
    <span class="c1"># This loop is effectively a hand-written SGD optimizer</span>
    <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">rnn</span><span class="o">.</span><span class="n">parameters</span><span class="p">():</span>
        <span class="n">p</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">add_</span><span class="p">(</span><span class="n">p</span><span class="o">.</span><span class="n">grad</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=-</span><span class="n">learning_rate</span><span class="p">)</span>

    <span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
</pre></div>
</div>
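<p>Reading the code together with its comments: for a variable-length sequence, every step before the last character passes <code class="docutils literal notranslate"><span class="pre">hidden</span></code> on to the next step, carrying the history forward. An <code class="docutils literal notranslate"><span class="pre">output</span></code> is produced at every step, but only the one produced after the last character is used to predict the text class and compute the loss.</p>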
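<p>The parameter-update loop at the end of <code>train</code> is plain stochastic gradient descent, <span class="math notranslate nohighlight">\(p \leftarrow p - \eta \nabla_p L\)</span>. The sketch below, using a hypothetical tiny model in place of the RNN, checks that one manual step and one step of <code>torch.optim.SGD</code> (default settings: no momentum, no weight decay) produce the same parameters.</p>

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
learning_rate = 0.005
criterion = nn.NLLLoss()   # pairs with a LogSoftmax output

# Hypothetical stand-in model; any module with parameters works here
model_a = nn.Sequential(nn.Linear(4, 3), nn.LogSoftmax(dim=1))
model_b = copy.deepcopy(model_a)   # identical starting weights
x = torch.randn(1, 4)
target = torch.tensor([2])         # class index, like category_tensor

# (a) Manual update, as in train(): p = p - learning_rate * grad
model_a.zero_grad()
criterion(model_a(x), target).backward()
with torch.no_grad():
    for p in model_a.parameters():
        p.add_(p.grad, alpha=-learning_rate)

# (b) The same step with the built-in optimizer
optimizer = torch.optim.SGD(model_b.parameters(), lr=learning_rate)
optimizer.zero_grad()
criterion(model_b(x), target).backward()
optimizer.step()

same = all(torch.allclose(pa, pb) for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```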
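<p><code class="docutils literal notranslate"><span class="pre">train</span></code> performs one update on a single name, but in practice we need to train on many names. The example code uses random sampling: each iteration draws one name at random from the whole dataset and trains on it, which produces the final result:</p>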
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">time</span>
<span class="kn">import</span> <span class="nn">math</span>
<span class="c1"># Number of training iterations</span>
<span class="n">n_iters</span> <span class="o">=</span> <span class="mi">100000</span>
<span class="c1"># Print progress every print_every iterations</span>
<span class="n">print_every</span> <span class="o">=</span> <span class="mi">5000</span>
<span class="c1"># Record the average loss every plot_every iterations</span>
<span class="n">plot_every</span> <span class="o">=</span> <span class="mi">1000</span>

<span class="c1"># Keep track of losses for plotting</span>
<span class="n">current_loss</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">all_losses</span> <span class="o">=</span> <span class="p">[]</span>

<span class="k">def</span> <span class="nf">timeSince</span><span class="p">(</span><span class="n">since</span><span class="p">):</span>
    <span class="n">now</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
    <span class="n">s</span> <span class="o">=</span> <span class="n">now</span> <span class="o">-</span> <span class="n">since</span>
    <span class="n">m</span> <span class="o">=</span> <span class="n">math</span><span class="o">.</span><span class="n">floor</span><span class="p">(</span><span class="n">s</span> <span class="o">/</span> <span class="mi">60</span><span class="p">)</span>
    <span class="n">s</span> <span class="o">-=</span> <span class="n">m</span> <span class="o">*</span> <span class="mi">60</span>
    <span class="k">return</span> <span class="s1">&#39;</span><span class="si">%d</span><span class="s1">m </span><span class="si">%d</span><span class="s1">s&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="n">m</span><span class="p">,</span> <span class="n">s</span><span class="p">)</span>
<span class="c1"># Start time</span>
<span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>

<span class="k">for</span> <span class="nb">iter</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">n_iters</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span>
    <span class="c1"># Randomly sample one name and its label</span>
    <span class="n">category</span><span class="p">,</span> <span class="n">line</span><span class="p">,</span> <span class="n">category_tensor</span><span class="p">,</span> <span class="n">line_tensor</span> <span class="o">=</span> <span class="n">randomTrainingExample</span><span class="p">()</span>
    <span class="c1"># One training step</span>
    <span class="n">output</span><span class="p">,</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">train</span><span class="p">(</span><span class="n">category_tensor</span><span class="p">,</span> <span class="n">line_tensor</span><span class="p">)</span>
    <span class="c1"># Accumulate the loss</span>
    <span class="n">current_loss</span> <span class="o">+=</span> <span class="n">loss</span>

    <span class="c1"># Print iteration number, progress, elapsed time, loss, name and guess</span>
    <span class="k">if</span> <span class="nb">iter</span> <span class="o">%</span> <span class="n">print_every</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
        <span class="n">guess</span><span class="p">,</span> <span class="n">guess_i</span> <span class="o">=</span> <span class="n">categoryFromOutput</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>
        <span class="n">correct</span> <span class="o">=</span> <span class="s1">&#39;✓&#39;</span> <span class="k">if</span> <span class="n">guess</span> <span class="o">==</span> <span class="n">category</span> <span class="k">else</span> <span class="s1">&#39;✗ (</span><span class="si">%s</span><span class="s1">)&#39;</span> <span class="o">%</span> <span class="n">category</span>
        <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;</span><span class="si">%d</span><span class="s1"> </span><span class="si">%d%%</span><span class="s1"> (</span><span class="si">%s</span><span class="s1">) </span><span class="si">%.4f</span><span class="s1"> </span><span class="si">%s</span><span class="s1"> / </span><span class="si">%s</span><span class="s1"> </span><span class="si">%s</span><span class="s1">&#39;</span> <span class="o">%</span> <span class="p">(</span><span class="nb">iter</span><span class="p">,</span> <span class="nb">iter</span> <span class="o">/</span> <span class="n">n_iters</span> <span class="o">*</span> <span class="mi">100</span><span class="p">,</span> <span class="n">timeSince</span><span class="p">(</span><span class="n">start</span><span class="p">),</span> <span class="n">loss</span><span class="p">,</span> <span class="n">line</span><span class="p">,</span> <span class="n">guess</span><span class="p">,</span> <span class="n">correct</span><span class="p">))</span>

    <span class="c1"># Store the average loss of the last plot_every iterations for plotting</span>
    <span class="k">if</span> <span class="nb">iter</span> <span class="o">%</span> <span class="n">plot_every</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
        <span class="n">all_losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">current_loss</span> <span class="o">/</span> <span class="n">plot_every</span><span class="p">)</span>
        <span class="n">current_loss</span> <span class="o">=</span> <span class="mi">0</span>
</pre></div>
</div>
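<p>The loop relies on <code>randomTrainingExample</code> and <code>categoryFromOutput</code>, which are defined in the cells skipped above. A minimal sketch of both, using hypothetical toy data in the <code>{language: [names ...]}</code> format (this version samples with <code>random.choice</code>; the tutorial's own helpers may differ in detail):</p>

```python
import random
import string

import torch

all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)

# Hypothetical toy data in the {language: [names ...]} format
category_lines = {'English': ['Albert', 'Jones'], 'Italian': ['Rossi', 'Bianchi']}
all_categories = list(category_lines)


def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][all_letters.find(letter)] = 1
    return tensor


def randomTrainingExample():
    # Pick a language uniformly, then a name from that language
    category = random.choice(all_categories)
    line = random.choice(category_lines[category])
    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
    return category, line, category_tensor, lineToTensor(line)


def categoryFromOutput(output):
    # output: (1, n_categories) log-probabilities; take the most likely class
    top_n, top_i = output.topk(1)
    category_i = top_i[0].item()
    return all_categories[category_i], category_i


random.seed(0)
category, line, category_tensor, line_tensor = randomTrainingExample()
print(category, line, line_tensor.shape)
print(categoryFromOutput(torch.tensor([[-2.0, -0.1]])))  # ('Italian', 1)
```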
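<p>That completes the walkthrough of the RNN training code. Once the training process is understood, saving the model and running inference are straightforward, so they are not covered in detail here.</p>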
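<p>For reference, a minimal sketch of both steps, under the assumption that the model is saved through its <code>state_dict</code>: <code>evaluate</code> is the loop from <code>train</code> without the loss, backward pass and update, and an in-memory buffer stands in for a file on disk. The class is repeated so the snippet runs on its own; 18 categories is an assumed value.</p>

```python
import io
import string

import torch
import torch.nn as nn


class RNN(nn.Module):
    # Same structure as the RNN class shown earlier
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.softmax(self.i2o(combined))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)


def evaluate(rnn, line_tensor):
    """Same loop as train(), but with no loss, no backward pass and no update."""
    hidden = rnn.initHidden()
    with torch.no_grad():
        for i in range(line_tensor.size(0)):
            output, hidden = rnn(line_tensor[i], hidden)
    return output


torch.manual_seed(0)
rnn = RNN(57, 128, 18)

# Save only the learned parameters, then load them into a fresh model
buffer = io.BytesIO()              # stands in for a file on disk
torch.save(rnn.state_dict(), buffer)
buffer.seek(0)
restored = RNN(57, 128, 18)
restored.load_state_dict(torch.load(buffer))

# One-Hot encoding of 'Albert', as produced by lineToTensor
all_letters = string.ascii_letters + " .,;'"
line_tensor = torch.zeros(6, 1, len(all_letters))
for i, c in enumerate('Albert'):
    line_tensor[i][0][all_letters.find(c)] = 1

print(torch.equal(evaluate(rnn, line_tensor), evaluate(restored, line_tensor)))  # True
```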
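<p>Now let us analyze the RNN once more in the context of this text-classification task. Here each individual word (a name) is one sequence, and the differing lengths of the words do not affect the training process. The values in the sequence are characters, and every character is represented by a vector of the same shape, <span class="math notranslate nohighlight">\(1 \times d\)</span>, which keeps the training procedure uniform.</p>
<p>By extension, the word-embedding models studied earlier, such as word2vec and GloVe, can turn words into vectors: a sentence becomes a sequence and each word becomes one value in it, which lets us classify whole sentences.</p>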
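<p>A minimal sketch of word-level classification, assuming pretrained word vectors are available: here a random tensor and a five-word vocabulary stand in for real word2vec or GloVe vectors (both hypothetical), loaded through <code>nn.Embedding.from_pretrained</code>. Each word becomes one time step, and the same recurrence as the character-level model is applied.</p>

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical vocabulary; real vectors would come from word2vec or GloVe
vocab = {'this': 0, 'movie': 1, 'was': 2, 'great': 3, 'boring': 4}
embed_dim, hidden_size, n_classes = 50, 32, 2
pretrained = torch.randn(len(vocab), embed_dim)   # placeholder word vectors
embedding = nn.Embedding.from_pretrained(pretrained, freeze=True)

# Same recurrence as the character-level RNN, with word vectors as X_t
i2h = nn.Linear(embed_dim + hidden_size, hidden_size)
i2o = nn.Linear(embed_dim + hidden_size, n_classes)
log_softmax = nn.LogSoftmax(dim=1)


def classify(sentence):
    """Each word is one time step; the output after the last word is the prediction."""
    ids = torch.tensor([vocab[w] for w in sentence.split()])
    word_vectors = embedding(ids).unsqueeze(1)    # (n_words, 1, embed_dim)
    hidden = torch.zeros(1, hidden_size)
    for x_t in word_vectors:
        combined = torch.cat((x_t, hidden), 1)
        hidden = i2h(combined)
        output = log_softmax(i2o(combined))
    return output


print(classify('this movie was great').shape)  # torch.Size([1, 2])
print(classify('boring').shape)                # torch.Size([1, 2])
```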
</section>
<section class="tex2jax_ignore mathjax_ignore" id="id4">
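<h1>Problems with RNNs<a class="headerlink" href="#id4" title="永久链接至标题">#</a></h1>
<p>We have seen how RNNs overcome the inability of simple neural networks to handle sequences. But is the RNN flawless and applicable to every sequence task? Certainly not.</p>
<p>The reason is a major weakness of RNNs: exploding and vanishing gradients.</p>
<p>Looking again at the code and the formulas, it is easy to see that the loss is computed and gradients are backpropagated only when the sequence reaches its end. Expanding <span class="math notranslate nohighlight">\(\mathbf{H}_t\)</span> reveals a product of <span class="math notranslate nohighlight">\(\mathbf{H}_{t-1}\)</span> with <span class="math notranslate nohighlight">\(\mathbf{W}_{hh}\)</span>, and expanding <span class="math notranslate nohighlight">\(\mathbf{H}_{t-1}\)</span> reveals another <span class="math notranslate nohighlight">\(\mathbf{W}_{hh}\)</span>, and so on, so the gradient flowing back to early time steps is multiplied by <span class="math notranslate nohighlight">\(\mathbf{W}_{hh}\)</span> over and over. Roughly speaking, if <span class="math notranslate nohighlight">\(\mathbf{W}_{hh}\)</span> amplifies vectors (its largest singular value is greater than 1), the gradient can grow exponentially over <span class="math notranslate nohighlight">\(t\)</span> multiplications and explode; if it shrinks them (its largest singular value is below 1), the gradient decays exponentially and vanishes. The gradient with respect to <span class="math notranslate nohighlight">\(\mathbf{W}_{xh}\)</span> suffers in the same way, because the contribution of an early input <span class="math notranslate nohighlight">\(\mathbf{X}_k\)</span> must travel back through the same chain of <span class="math notranslate nohighlight">\(\mathbf{W}_{hh}\)</span> factors.</p>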
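<p>The repeated multiplication can be observed directly with autograd. This sketch strips the recurrence down to <span class="math notranslate nohighlight">\(\mathbf{H}_t = \mathbf{H}_{t-1}\mathbf{W}_{hh}\)</span> with no input, bias or activation, and picks <span class="math notranslate nohighlight">\(\mathbf{W}_{hh} = s\mathbf{I}\)</span> (a hypothetical choice that makes every singular value equal to <span class="math notranslate nohighlight">\(s\)</span>), so the gradient reaching <span class="math notranslate nohighlight">\(\mathbf{H}_0\)</span> after <span class="math notranslate nohighlight">\(t\)</span> steps is exactly <span class="math notranslate nohighlight">\(s^t\)</span>.</p>

```python
import torch


def grad_wrt_h0(scale, steps, size=4):
    """Gradient of sum(H_T) w.r.t. H_0 for H_t = H_{t-1} @ W_hh with W_hh = scale * I."""
    W_hh = scale * torch.eye(size)
    h0 = torch.ones(1, size, requires_grad=True)
    h = h0
    for _ in range(steps):
        h = h @ W_hh               # the same W_hh is applied at every step
    h.sum().backward()
    return h0.grad[0, 0].item()    # equals scale ** steps


print(grad_wrt_h0(0.5, 20))  # 9.5367431640625e-07 (= 0.5 ** 20): vanishing
print(grad_wrt_h0(1.5, 20))  # about 3325.26 (= 1.5 ** 20): exploding
```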
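<p>Vanishing gradients can also be understood this way: suppose the combined hidden information of all earlier steps accounts for 0.5 of the state and the current step's information accounts for the other 0.5. As the sequence grows, the share held by the earliest steps drops exponentially, from 1/2 to 1/4 to 1/8, and so on. The figure below may help with this intuition. For this reason, vanishing gradients are a more serious concern in RNNs than exploding gradients: the model effectively remembers only the states of the few most recent steps, struggles to take more distant steps into account, and cannot relate context well.</p>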
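<p>The 1/2, 1/4, 1/8 pattern can be reproduced with a toy recurrence <span class="math notranslate nohighlight">\(\mathbf{H}_t = 0.5\,\mathbf{H}_{t-1} + 0.5\,\mathbf{X}_t\)</span> starting from <span class="math notranslate nohighlight">\(\mathbf{H}_0 = 0\)</span> (an illustrative simplification, not the RNN above; the helper <code>unrolled_weights</code> is hypothetical). The sketch tracks the coefficient of each input in the current state.</p>

```python
def unrolled_weights(seq_len, keep=0.5):
    """Coefficient of x_1..x_T in H_T for H_t = keep * H_{t-1} + (1 - keep) * x_t."""
    weights = []
    for _ in range(seq_len):
        # Older inputs are scaled down by `keep`; the new input enters with 1 - keep
        weights = [w * keep for w in weights] + [1 - keep]
    return weights


print(unrolled_weights(5))      # [0.03125, 0.0625, 0.125, 0.25, 0.5]
print(unrolled_weights(20)[0])  # 9.5367431640625e-07: the first input is almost forgotten
```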
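<p>To address this, gradient clipping can cap the gradient norm at a threshold (for example 1) so that gradients cannot explode; clipping, however, does nothing against vanishing gradients. For that, see the next article on LSTM, GRU and other algorithms that improve on the RNN.</p>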
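<p>In PyTorch, clipping by norm is available as <code>torch.nn.utils.clip_grad_norm_</code>. The sketch below uses a hypothetical parameter with a deliberately large gradient: the gradient is rescaled to norm 1 while its direction is kept.</p>

```python
import torch

# Hypothetical parameter with a deliberately large gradient
p = torch.nn.Parameter(torch.zeros(3))
p.grad = torch.tensor([30.0, 40.0, 0.0])   # L2 norm = 50

# Rescale the gradient in place so its total norm is at most max_norm
total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)

print(total_norm.item())     # 50.0: the norm before clipping
print(p.grad.norm().item())  # about 1.0: same direction, norm capped at max_norm
```

<p>In the <code>train</code> function above, the call would go between <code>loss.backward()</code> and the manual parameter update, for example <code>torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1.0)</code>.</p>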
</section>


              </div>
              
            </main>
            <footer class="footer-article noprint">
                
    <!-- Previous / next buttons -->
<div class='prev-next-area'>
    <a class='left-prev' id="prev-link" href="ResNet%E6%BA%90%E7%A0%81%E8%A7%A3%E8%AF%BB.html" title="上一页 页">
        <i class="fas fa-angle-left"></i>
        <div class="prev-next-info">
            <p class="prev-next-subtitle">Previous</p>
            <p class="prev-next-title">ResNet Source Code Walkthrough</p>
        </div>
    </a>
    <a class='right-next' id="next-link" href="LSTM%E8%A7%A3%E8%AF%BB%E5%8F%8A%E5%AE%9E%E6%88%98.html" title="下一页 页">
    <div class="prev-next-info">
        <p class="prev-next-subtitle">Next</p>
        <p class="prev-next-title">Article Structure</p>
    </div>
    <i class="fas fa-angle-right"></i>
    </a>
</div>
            </footer>
        </div>
    </div>
    <div class="footer-content row">
        <footer class="col footer"><p>
  
    By ZhikangNiu<br/>
  
      &copy; Copyright 2022, ZhikangNiu.<br/>
</p>
        </footer>
    </div>
    
</div>


      </div>
    </div>
  
  <!-- Scripts loaded after <body> so the DOM is not blocked -->
  <script src="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"></script>


  </body>
</html>